#!/usr/bin/env python3
#
# SPDX-FileCopyrightText: Copyright 2010-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import re
import sys
import json
import copy
import glob
import time
import queue
import shutil
import serial
import argparse
import threading
import subprocess

from os import path
from termcolor import colored

# Directory name prefix under which per-target build directories are created.
OUTPUT = "Output/"
# Relative path prepended to test and Unity sources when invoking mbed from
# inside a per-target build directory.
BASE_PATH = "../../"
# Directory (relative to the start directory) that holds the Unity framework.
UNITY_PATH = "Unity/"
# Unity location as seen from inside a per-target build directory.
UNITY_BASE = f"{BASE_PATH}{UNITY_PATH}"
UNITY_SRC = f"{UNITY_BASE}src/"


def parse_args():
    """Parse the command line of the unit test runner.

    Returns:
        argparse.Namespace with the attributes ``testdir``, ``specific_test``,
        ``compiler`` and ``download_and_generate``.
    """
    parser = argparse.ArgumentParser(description="Run CMSIS-NN unit tests.",
                                     epilog="Runs on all connected HW supported by Mbed.")
    parser.add_argument('--testdir', type=str, default='TESTRUN', help="prefix of output dir name")
    parser.add_argument('-s',
                        '--specific-test',
                        type=str,
                        default=None,
                        help="Run a specific test, e.g."
                        " -s TestCases/test_arm_avgpool_s8 (also this form will work: -s test_arm_avgpool_s8)."
                        " So basically the different options can be listed with:"
                        " ls -d TestCases/test_* -1")
    parser.add_argument('-c', '--compiler', type=str, default='GCC_ARM', choices=['GCC_ARM', 'ARMC6'])
    parser.add_argument('--download-and-generate-test-runners',
                        dest='download_and_generate',
                        action='store_true',
                        help="Just download Unity and generate test runners if needed")

    # No option is mandatory, so no "required named arguments" group is created.
    return parser.parse_args()


def error_handler(code, text=None):
    """Print ``text`` prefixed with "Error: " and terminate the process.

    Args:
        code: Exit status handed to ``sys.exit``.
        text: Message (or exception) to report; printed as ``None`` if omitted.
    """
    message = f"Error: {text}"
    print(message)
    sys.exit(code)


def detect_targets(targets):
    """Run ``mbedls`` and append one dict per detected board to ``targets``.

    Every output line is echoed. Lines starting with "| " are treated as table
    rows and turned into a dict with the keys ``model``, ``name``, ``port`` and
    ``tid``. The very first output line is echoed but never parsed.

    Args:
        targets: List that is extended in place with the detected targets.

    Exits through ``error_handler`` when ``mbedls`` finishes with a non-zero
    status.
    """
    process = subprocess.Popen(['mbedls'], stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True)
    print(process.stdout.readline().strip())
    while True:
        line = process.stdout.readline()
        print(line.strip())
        if not line:
            break
        if re.search(r"^\| ", line):
            words = (line.split('| '))
            target = {
                "model": words[1].strip(),
                "name": words[2].strip()[:-1].replace('[', '_'),
                "port": words[4].strip(),
                "tid": words[5].strip()
            }  # Target id can be used to filter out targets
            targets.append(target)

    # End-of-file on stdout does not mean the child has exited: poll() could
    # still return None here, and None would be reported as a failure that
    # exits with status 0. communicate() waits for termination and drains
    # stderr in one go.
    _, stderr_text = process.communicate()
    return_code = process.returncode
    if return_code != 0:
        error_handler(return_code, 'RETURN CODE {}'.format(stderr_text))


def run_command(command, error_msg=None, die=True):
    """Run ``command`` as a subprocess and return its exit status.

    The command string is split on single spaces, so individual arguments
    must not contain spaces.

    Args:
        command: Command line to execute.
        error_msg: Message handed to ``error_handler`` on failure.
        die: When true, a non-zero exit status terminates the script through
            ``error_handler``.

    Returns:
        The exit status of the command.
    """
    # TODO handle error:
    # cp: error writing '/media/mannil01/NODE_F411RE/TESTRUN_NUCLEO_F411RE_GCC_ARM.bin': No space left on device
    # https://os.mbed.com/questions/59636/STM-Nucleo-No-space-left-on-device-when-/
    completed = subprocess.run(command.split(' '))
    status = completed.returncode
    if status != 0 and die:
        error_handler(status, error_msg)
    return status


def detect_architecture(target_name, target_json):
    """Look up the core of ``target_name`` in an Mbed targets JSON file.

    Targets without a ``core`` entry are resolved by following the first
    element of their ``inherits`` list until an entry with ``core`` is found.

    Args:
        target_name: Key of the target in the JSON data.
        target_json: Path of the JSON file describing the targets.

    Returns:
        The first nine characters of the core name when it starts with
        "Cortex-M". Any other core exits with code 168, and any exception
        raised while reading or walking the data exits with code 167.
    """
    arch = None
    try:
        with open(target_json, "r") as read_file:
            data = json.load(read_file)

        entry = data[target_name]
        while 'core' not in entry:
            parent = entry['inherits'][0]
            print(f"{target_name} inherits from {parent}")
            target_name = parent
            entry = data[target_name]
        core = entry['core']

        if core:
            arch = core[:9]
            if core[:8] == 'Cortex-M':
                return arch
        error_handler(168, 'Unsupported target: {} with architecture: {}'.format(target_name, arch))
    except Exception as e:
        error_handler(167, e)

    return arch


def test_target(target, args, main_test):
    """Build one test with ``mbed compile`` for a single target.

    A per-target directory is created below ``OUTPUT`` and entered. On first
    use the Mbed project skeleton is copied in and deployed, and some mbed-os
    sub-directories are removed. The working directory is restored before
    returning.

    Args:
        target: Target dict providing the keys ``name`` and ``model``.
        args: Parsed command line; ``compiler`` and ``testdir`` are used.
        main_test: Path of the test's Unity directory, relative to BASE_PATH.

    Returns:
        The exit status of the ``mbed compile`` command. Any exception raised
        while preparing or building exits the script with code 166.
    """
    result = 3
    toolchain = args.compiler
    cmsis_nn_path = "../../../../"
    target_name = target['name']
    target_model = target['model']
    unittestframework = 'UNITY_UNITTEST'

    dir_name = '_'.join([OUTPUT + args.testdir, unittestframework, target_name, toolchain])

    os.makedirs(dir_name, exist_ok=True)
    start_dir = os.getcwd()
    os.chdir(dir_name)

    try:
        target_json = 'mbed-os/targets/targets.json'
        mbed_path = BASE_PATH + 'Mbed/'

        if not path.exists("mbed-os.lib"):
            print("Initializing mbed in {}".format(os.getcwd()))
            run_command(f'cp -a {mbed_path}. .')
            run_command('mbed deploy')
            # Drop mbed-os sub-directories that are not needed for the build.
            pruned = ('TESTS', 'UNITTESTS', 'docker_images', 'docs', 'extern', 'features')
            run_command('rm -rf ' + ' '.join('mbed-os/' + name for name in pruned))
        arch = detect_architecture(target_model, target_json)

        print("----------------------------------------------------------------")
        print("Running {} on {} target: {} with compiler: {} in directory: {} test: {}\n".format(
            unittestframework, arch, target_name, toolchain, os.getcwd(), main_test))

        # Source directories in the order they appear on the command line.
        library_dirs = ('Include/', 'Source/ConvolutionFunctions/', 'Source/PoolingFunctions/',
                        'Source/NNSupportFunctions/', 'Source/FullyConnectedFunctions/',
                        'Source/SoftmaxFunctions/', 'Source/SVDFunctions/', 'Source/BasicMathFunctions/',
                        'Source/ActivationFunctions/', 'Source/LSTMFunctions/')
        sources = ['.', BASE_PATH + 'TestCases/Utils/']
        sources += [cmsis_nn_path + sub_dir for sub_dir in library_dirs]
        sources += [BASE_PATH + main_test, UNITY_SRC]

        command = 'mbed compile -v -m ' + target_model + ' -t ' + toolchain
        command += ''.join(' --source ' + source for source in sources)
        command += ' --profile ' + mbed_path + 'release.json' + ' -f'

        # A failing build is reported through the return value, not by exiting.
        result = run_command(command, 'failed to flash', die=False)

    except Exception as e:
        error_handler(166, e)

    os.chdir(start_dir)
    return result


def read_serial_port(ser, inputQueue, stop):
    """Forward lines read from ``ser`` to ``inputQueue`` until told to stop.

    Args:
        ser: Object offering ``readline()`` that returns bytes.
        inputQueue: Queue receiving each line, decoded as latin-1 and stripped.
        stop: Callable polled before every read; a true result ends the loop.
    """
    while not stop():
        raw = ser.readline()
        text = raw.decode('latin-1')
        inputQueue.put(text.strip())


def test_target_with_unity(target, args, main_test):
    port = target['port']
    stop_thread = False
    baudrate = 9600
    timeout = 30
    inputQueue = queue.Queue()
    tests = copy.deepcopy(target["tests"])

    try:
        ser = serial.Serial(port, baudrate, timeout=timeout)
    except Exception as e:
        error_handler(169, "serial exception: {}".format(e))

    # Clear read buffer
    time.sleep(0.1)  # Workaround in response to: open() returns before port is ready
    ser.reset_input_buffer()

    serial_thread = threading.Thread(target=read_serial_port, args=(ser, inputQueue, lambda: stop_thread), daemon=True)
    serial_thread.start()

    test_target(target, args, main_test)

    start_time = time.time()
    while time.time() < start_time + timeout:
        if inputQueue.qsize() > 0:
            str_line = inputQueue.get()
            print(str_line)
            test = None
            try:
                test = str_line.split(':')[2]
                test_result = ':'.join(str_line.split(':')[2:4])
            except IndexError:
                pass
            if test in tests:
                tests.remove(test)
                target[test]["tested"] = True
                if test_result == test + ':PASS':
                    target[test]["pass"] = True
            if len(tests) == 0:
                break

    stop_thread = True
    serial_thread.join()
    ser.close()


def print_summary(targets):
    """
    Print the verdict of every test case followed by an overall summary.

    Return 0 if all test passed
    Return 1 if all test completed but one or more failed
    Return 2 if one or more tests did not complete or was not detected
    """
    verdict_pass = colored('[ PASSED ]', 'green')
    verdict_fail = colored('[ FAILED ]', 'red')
    verdict_error = colored('[ ERROR ]', 'red')

    print("-----------------------------------------------------------------------------------------------------------")

    # Count expected, executed and passed tests; every test that did not pass
    # counts as failed, whether it ran or not.
    outcomes = [target[test] for target in targets for test in target["tests"]]
    expected = len(outcomes)
    tested = sum(1 for outcome in outcomes if outcome["tested"])
    passed = sum(1 for outcome in outcomes if outcome["pass"])
    failed = expected - passed

    if tested != expected:
        print("{} Not all tests found".format(verdict_error))
        print("{} Expected: {} Actual: {}".format(verdict_error, expected, tested))
        return_code = 2
    else:
        return_code = 0 if tested == passed else 1

    # One line per test case, printed in sorted order.
    report = []
    for target in targets:
        for test in target["tests"]:
            outcome = target[test]
            if not outcome["tested"]:
                tc_verdict = verdict_error
            else:
                tc_verdict = verdict_pass if outcome["pass"] else verdict_fail
            report.append("{} {}: {}".format(tc_verdict, target["name"], test))
    for tc in sorted(report):
        print(tc)

    # Guarding on passed also avoids dividing by zero when nothing was expected.
    total = passed / expected if passed > 0 else 0
    verdict = verdict_pass if total == 1.0 else verdict_fail
    target_names = ', '.join([t['name'] for t in targets])
    print("{} Summary: {} tests in total passed on {} target(s) ({})".format(verdict, passed, len(targets),
                                                                             target_names))
    print("{} {:.0f}% tests passed, {} tests failed out of {}".format(verdict, total * 100, failed, expected))

    return return_code


def test_targets(args):
    """
    Detect targets, prepare Unity and the test runners, then run every test.

    With ``args.download_and_generate`` set, target detection and test
    execution are skipped; only Unity and the test runners are prepared.

    Return 0 if successful
    Return 3 if no targets are detected
    Return 4 if no tests are found
    Otherwise return the code computed by print_summary()
    """
    targets = []
    main_tests = []
    generate_only = args.download_and_generate

    if not generate_only:
        detect_targets(targets)
        if not targets:
            print("No targets detected!")
            return 3

    download_unity()

    tests_available = parse_tests(targets, main_tests, args.specific_test)
    if not tests_available:
        print("No tests found?!")
        return 4

    if generate_only:
        return 0

    for target in targets:
        for tst in main_tests:
            test_target_with_unity(target, args, tst)

    return print_summary(targets)


def download_unity(force=False):
    """Download Unity v2.5.0 into ``UNITY_PATH`` unless it is already there.

    Unity counts as present when ``UNITY_PATH`` exists and contains both
    ``src/unity.c`` and ``src/unity.h``. Otherwise the directory is recreated,
    the release tarball is fetched with ``curl`` into a temporary directory
    and unpacked with ``tar``.

    Args:
        force: Download even if Unity is already present.

    Exits with code 171 when the download fails or produces no data.
    """
    import tempfile  # Local import: only this function needs it.

    unity_dir = UNITY_PATH
    unity_src = unity_dir + "src/"

    # Check if already downloaded, before spending any work on temp files.
    already_present = (path.isdir(unity_dir) and path.isfile(unity_src + "unity.c")
                       and path.isfile(unity_src + "unity.h"))
    if not force and already_present:
        return

    if path.isdir(unity_dir):
        shutil.rmtree(unity_dir)
    os.mkdir(unity_dir)

    # TemporaryDirectory creates the directory atomically and removes it again
    # on every exit path, including the SystemExit raised by error_handler.
    with tempfile.TemporaryDirectory() as download_dir:
        downloaded_file = path.join(download_dir, "unity_tarball.tar.gz")
        # --fail makes curl return a non-zero status on HTTP errors instead of
        # saving the server's error page as the tarball.
        process = subprocess.run([
            'curl', '--fail', '-LJ', 'https://api.github.com/repos/ThrowTheSwitch/Unity/tarball/v2.5.0', '--output',
            downloaded_file
        ],
                                 stdout=subprocess.PIPE,
                                 stderr=subprocess.PIPE,
                                 universal_newlines=True)
        for line in process.stderr.splitlines():
            print(line.strip())
        print()

        download_ok = (process.returncode == 0 and path.isfile(downloaded_file)
                       and path.getsize(downloaded_file) > 0)
        if not download_ok:
            error_handler(171, "failed to download Unity (curl exit status {})".format(process.returncode))

        run_command("tar xzf " + downloaded_file + " -C " + unity_dir + " --strip-components=1")


def parse_generated_test_runner(test_runner):
    """Rewrite a generated test runner without selected ``(void);`` lines.

    A line is dropped when it contains ``(void);`` and also mentions one of
    setUp, tearDown, resetTest or verifyTest. All other lines are written back
    unchanged.

    Args:
        test_runner: Path of the generated runner file, modified in place.
    """
    filtered_names = ('setUp', 'tearDown', 'resetTest', 'verifyTest')
    declaration = re.compile(r"\(void\);")

    def should_drop(text):
        # Only "(void);" lines naming one of the listed functions go.
        if not declaration.search(text):
            return False
        return any(name in text for name in filtered_names)

    with open(test_runner, "r") as runner:
        original_lines = runner.readlines()

    kept_lines = [line for line in original_lines if not should_drop(line.strip('\n'))]

    with open(test_runner, "w") as runner:
        runner.writelines(kept_lines)


def parse_tests(targets, main_tests, specific_test=None):
    """
    Generate test runners, extract and return path to unit test(s).
    Also parse generated test runners to avoid warning: redundant redeclaration.

    Args:
        targets: Target dicts; each receives the tests found (see parse_test).
        main_tests: List extended in place with the Unity directory of every
            usable test.
        specific_test: Optional single test to handle, given either as a bare
            directory name or as a path such as ``TestCases/<name>``.

    Return True if successful.
    """
    test_found = False
    directory = 'TestCases'

    if specific_test and '/' in specific_test:
        # Keep only the last path component. str.strip(directory) would remove
        # a *set of characters*, which also truncates test names ending in
        # any of the letters of "TestCases".
        specific_test = path.basename(specific_test.rstrip('/'))

    # A missing test directory yields no entries instead of StopIteration.
    test_dirs = next(os.walk(directory), (directory, [], []))[1]

    for test_dir in test_dirs:
        if not re.search(r'test_arm', test_dir):
            continue
        if specific_test and test_dir != specific_test:
            continue
        test_found = True
        testpath = directory + '/' + test_dir + '/Unity/'
        ut_test_file = None
        for content in os.listdir(testpath):
            if re.search(r'unity_test_arm', content):
                ut_test_file = content
        if ut_test_file is None:
            print("Warning: invalid path: ", testpath)
            continue
        main_tests.append(testpath)
        base_name, extension = path.splitext(ut_test_file)
        ut_test_file_runner = base_name + '_runner' + extension
        test_code = testpath + ut_test_file
        test_runner_path = testpath + 'TestRunner/'
        if not os.path.exists(test_runner_path):
            os.mkdir(test_runner_path)
        test_runner = test_runner_path + ut_test_file_runner
        # Remove stale runners from earlier runs, but keep the readme.
        for old_file in glob.glob(test_runner_path + '*'):
            if not old_file.endswith('readme.txt'):
                os.remove(old_file)

        # Generate test runners
        run_command('ruby ' + UNITY_PATH + 'auto/generate_test_runner.rb ' + test_code + ' ' + test_runner)
        test_found = parse_test(test_runner, targets)
        if not test_found:
            return False

        parse_generated_test_runner(test_runner)

    return test_found


def parse_test(test_runner, targets):
    """Register the tests listed in a generated test runner on every target.

    Each line containing ``  run_test(`` with exactly three comma-separated
    parts names one test: the text between ``(`` and the first comma. That
    name is appended to ``target["tests"]`` and ``target[<name>]`` is reset to
    ``{"pass": False, "tested": False}`` for every target.

    Args:
        test_runner: Path of the generated runner file.
        targets: Target dicts, updated in place.

    Returns:
        True if at least one test was found. An unreadable file exits with
        code 170.
    """
    found_any = False

    # Get list of tests
    try:
        runner = open(test_runner, "r")
    except IOError as e:
        error_handler(170, e)
        return found_any

    with runner:
        for raw_line in runner:
            fields = raw_line.strip().split(',')
            if len(fields) != 3 or not re.search(r"  run_test\(", raw_line):
                continue
            function = fields[0].split('(')[1]
            found_any = True
            for target in targets:
                target.setdefault('tests', []).append(function)
                target[function] = {"pass": False, "tested": False}
    return found_any


if __name__ == '__main__':
    # Script entry point: the result of test_targets() becomes the exit status.
    args = parse_args()
    exit_status = test_targets(args)
    sys.exit(exit_status)
